from fastapi import Body, Request
from fastapi.responses import StreamingResponse
import asyncio, json
from typing import List

from langchain.chat_models import ChatOpenAI
from langchain.chains import LLMChain
from langchain.callbacks import AsyncIteratorCallbackHandler
from langchain.prompts.chat import ChatPromptTemplate


from dev_opsgpt.chat.utils import History, wrap_done
from configs.model_config import (llm_model_dict, LLM_MODEL, VECTOR_SEARCH_TOP_K, SCORE_THRESHOLD)
from dev_opsgpt.utils import BaseResponse
from loguru import logger


def getChatModel(callBack: AsyncIteratorCallbackHandler = None):
    '''Build a streaming ChatOpenAI client, optionally wired to an async iterator callback.'''
    model = ChatOpenAI(
        streaming=True,
        verbose=True,
        callbacks=[] if callBack is None else [callBack],
        openai_api_key=llm_model_dict[LLM_MODEL]["api_key"],
        openai_api_base=llm_model_dict[LLM_MODEL]["api_base_url"],
        model_name=LLM_MODEL
    )
    return model


class Chat:
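    '''
    Base chat handler: `chat` runs the LLM chain synchronously and simulates streaming,
    while `achat` streams tokens through an AsyncIteratorCallbackHandler. Both return a
    text/event-stream StreamingResponse and are written to be mounted as FastAPI POST
    endpoints (hence the Body(...) parameter defaults).
    '''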
    def __init__(
            self,
            engine_name: str = "",
            top_k: int = 1,
            stream: bool = False,
            ) -> None:
        self.engine_name = engine_name
        self.top_k = top_k
        self.stream = stream

    def check_service_status(self) -> BaseResponse:
        return BaseResponse(code=200, msg="ok")

    def chat(
            self, 
            query: str = Body(..., description="User input", examples=["hello"]),
            history: List[History] = Body(
                [], description="Conversation history",
                examples=[[{"role": "user", "content": "Let's play idiom chain, I'll go first: 生龙活虎"}]]
                ),
            engine_name: str = Body(..., description="Knowledge base name", examples=["samples"]),
            top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of vectors to match"),
            score_threshold: float = Body(SCORE_THRESHOLD, description="Relevance threshold for knowledge-base matching, in the range 0-1; a smaller score means higher relevance, 1 disables filtering, and around 0.5 is recommended", ge=0, le=1),
            stream: bool = Body(False, description="Stream the output"),
            local_doc_url: bool = Body(False, description="Return knowledge files as local paths (true) or URLs (false)"),
            request: Request = None,
            ):
        # When invoked outside FastAPI's dependency injection, the parameters arrive as
        # raw Body(...) field objects, so fall back to their declared defaults.
        self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
        self.top_k = top_k if isinstance(top_k, int) else top_k.default
        self.score_threshold = score_threshold if isinstance(score_threshold, float) else score_threshold.default
        self.stream = stream if isinstance(stream, bool) else stream.default
        self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
        self.request = request
        return self._chat(query, history)
    
    def _chat(self, query: str, history: List[History]):
        history = [History(**h) if isinstance(h, dict) else h for h in history]
        # check that service dependencies are available
        service_status = self.check_service_status()
        if service_status.code != 200:
            return service_status

        def chat_iterator(query: str, history: List[History]):
            model = getChatModel()

            result, content = self.create_task(query, history, model)
            if self.stream:
                # the synchronous chain call already produced the full text, so emit it
                # character by character to simulate a token stream
                for token in content["text"]:
                    result["answer"] = token
                    yield json.dumps(result, ensure_ascii=False)
            else:
                # accumulate the full answer and emit a single JSON payload
                for token in content["text"]:
                    result["answer"] += token
                yield json.dumps(result, ensure_ascii=False)

        return StreamingResponse(chat_iterator(query, history),
                                 media_type="text/event-stream")
        
    def achat(
            self, 
            query: str = Body(..., description="User input", examples=["hello"]),
            history: List[History] = Body(
                [], description="Conversation history",
                examples=[[{"role": "user", "content": "Let's play idiom chain, I'll go first: 生龙活虎"}]]
                ),
            engine_name: str = Body(..., description="Knowledge base name", examples=["samples"]),
            top_k: int = Body(VECTOR_SEARCH_TOP_K, description="Number of vectors to match"),
            score_threshold: float = Body(SCORE_THRESHOLD, description="Relevance threshold for knowledge-base matching, in the range 0-1; a smaller score means higher relevance, 1 disables filtering, and around 0.5 is recommended", ge=0, le=1),
            stream: bool = Body(False, description="Stream the output"),
            local_doc_url: bool = Body(False, description="Return knowledge files as local paths (true) or URLs (false)"),
            request: Request = None,
            ):
        # Same Body(...) default handling as in `chat` above.
        self.engine_name = engine_name if isinstance(engine_name, str) else engine_name.default
        self.top_k = top_k if isinstance(top_k, int) else top_k.default
        self.score_threshold = score_threshold if isinstance(score_threshold, float) else score_threshold.default
        self.stream = stream if isinstance(stream, bool) else stream.default
        self.local_doc_url = local_doc_url if isinstance(local_doc_url, bool) else local_doc_url.default
        self.request = request
        return self._achat(query, history)
    
    def _achat(self, query: str, history: List[History]):
        history = [History(**h) if isinstance(h, dict) else h for h in history]
        # check that service dependencies are available
        service_status = self.check_service_status()
        if service_status.code != 200:
            return service_status

        async def chat_iterator(query, history):
            callback = AsyncIteratorCallbackHandler()
            # attach the callback so generated tokens are pushed into callback.aiter()
            model = getChatModel(callback)

            task, result = self.create_atask(query, history, model, callback)
            if self.stream:
                # forward each token to the client as soon as the model emits it
                async for token in callback.aiter():
                    result["answer"] = token
                    yield json.dumps(result, ensure_ascii=False)
            else:
                # collect the full answer first, then emit a single JSON payload
                async for token in callback.aiter():
                    result["answer"] += token
                yield json.dumps(result, ensure_ascii=False)
            await task

        return StreamingResponse(chat_iterator(query, history),
                                 media_type="text/event-stream")
        
    def create_task(self, query: str, history: List[History], model):
        '''Build and run the LLM generation task synchronously.'''
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", "{input}")]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        content = chain({"input": query})
        return {"answer": "", "docs": ""}, content

    def create_atask(self, query, history, model, callback: AsyncIteratorCallbackHandler):
        '''Schedule the LLM generation as an asyncio task; tokens are surfaced through the callback.'''
        chat_prompt = ChatPromptTemplate.from_messages(
            [i.to_msg_tuple() for i in history] + [("human", "{input}")]
        )
        chain = LLMChain(prompt=chat_prompt, llm=model)
        # wrap_done is expected to await the chain coroutine and set callback.done on completion
        task = asyncio.create_task(wrap_done(
            chain.acall({"input": query}), callback.done
        ))
        return task, {"answer": "", "docs": ""}